{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load functions from openl3/cli.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "# %load ~/dev/openl3/openl3/cli.py\n",
    "from __future__ import print_function\n",
    "import os\n",
    "import sys\n",
    "from argparse import ArgumentParser, RawDescriptionHelpFormatter, ArgumentTypeError\n",
    "try:\n",
    "    # The ABC aliases in ``collections`` were removed in Python 3.10\n",
    "    from collections.abc import Iterable\n",
    "except ImportError:  # Python 2\n",
    "    from collections import Iterable\n",
    "from openl3.openl3_exceptions import OpenL3Error\n",
    "from openl3 import process_file\n",
    "from six import string_types\n",
    "\n",
    "\n",
    "def positive_float(value):\n",
    "    \"\"\"Argparse ``type`` callable that only accepts strictly positive floats.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    value : object\n",
    "        Raw value to convert, typically the command-line string.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    float\n",
    "        ``value`` converted with ``float()``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ArgumentTypeError\n",
    "        If ``value`` cannot be converted to a float, or if the converted\n",
    "        number is less than or equal to zero.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        number = float(value)\n",
    "    except (ValueError, TypeError) as err:\n",
    "        message = 'Expected a positive float, error message: {}'.format(err)\n",
    "        raise ArgumentTypeError(message)\n",
    "    else:\n",
    "        if number <= 0:\n",
    "            raise ArgumentTypeError('Expected a positive float')\n",
    "        return number\n",
    "\n",
    "\n",
    "def get_file_list(input_list):\n",
    "    \"\"\"Expand a collection of file/directory paths into a flat list of files.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    input_list : iterable of str\n",
    "        Paths to files and/or directories. A directory contributes the\n",
    "        regular files directly inside it; sub-directories are not\n",
    "        descended into.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    file_list : list of str\n",
    "        Absolute paths of all files found, following the order of\n",
    "        ``input_list``. Files taken from a directory are listed in sorted\n",
    "        order.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ArgumentTypeError\n",
    "        If ``input_list`` is a string or is not iterable.\n",
    "    OpenL3Error\n",
    "        If an item is neither an existing file nor an existing directory.\n",
    "    \"\"\"\n",
    "    if not isinstance(input_list, Iterable) or isinstance(input_list, string_types):\n",
    "        raise ArgumentTypeError('input_list must be iterable (and not string)')\n",
    "    file_list = []\n",
    "    for item in input_list:\n",
    "        if os.path.isfile(item):\n",
    "            file_list.append(os.path.abspath(item))\n",
    "        elif os.path.isdir(item):\n",
    "            # os.listdir returns entries in arbitrary order; sort them so the\n",
    "            # processing order is reproducible from run to run.\n",
    "            for fname in sorted(os.listdir(item)):\n",
    "                path = os.path.join(item, fname)\n",
    "                if os.path.isfile(path):\n",
    "                    # Absolute, to match the paths of explicitly given files\n",
    "                    file_list.append(os.path.abspath(path))\n",
    "        else:\n",
    "            raise OpenL3Error('Could not find {}'.format(item))\n",
    "\n",
    "    return file_list\n",
    "\n",
    "\n",
    "def run(inputs, output_dir=None, suffix=None, input_repr=\"mel256\", content_type=\"music\",\n",
    "        embedding_size=6144, center=True, hop_size=0.1, verbose=False):\n",
    "    \"\"\"\n",
    "    Computes and saves L3 embedding for given inputs.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    inputs : list of str, or str\n",
    "        File/directory path or list of file/directory paths to be processed\n",
    "    output_dir : str or None\n",
    "        Path to directory for saving output files. If None, output files will\n",
    "        be saved to the directory containing the input file.\n",
    "    suffix : str or None\n",
    "        String to be appended to the output filename, i.e. <base filename>_<suffix>.npz.\n",
    "        If None, then no suffix will be added, i.e. <base filename>.npz.\n",
    "    input_repr : \"linear\", \"mel128\", or \"mel256\"\n",
    "        Spectrogram representation used for model.\n",
    "    content_type : \"music\" or \"env\"\n",
    "        Type of content used to train embedding.\n",
    "    embedding_size : 6144 or 512\n",
    "        Embedding dimensionality.\n",
    "    center : boolean\n",
    "        If True, pads beginning of signal so timestamps correspond\n",
    "        to center of window.\n",
    "    hop_size : float\n",
    "        Hop size in seconds.\n",
    "    verbose : boolean\n",
    "        If True, print progress messages to stdout. The flag is also\n",
    "        forwarded to ``process_file``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    None\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    OpenL3Error\n",
    "        If ``inputs`` is neither a string nor an iterable, or if one of the\n",
    "        given paths is neither an existing file nor an existing directory.\n",
    "    SystemExit\n",
    "        If the inputs do not contain any file.\n",
    "    \"\"\"\n",
    "\n",
    "    if isinstance(inputs, string_types):\n",
    "        # A single path may name a directory as well as a file, so expand it\n",
    "        # exactly like a one-element list of paths.\n",
    "        file_list = get_file_list([inputs])\n",
    "    elif isinstance(inputs, Iterable):\n",
    "        file_list = get_file_list(inputs)\n",
    "    else:\n",
    "        raise OpenL3Error('Invalid input: {}'.format(str(inputs)))\n",
    "\n",
    "    if len(file_list) == 0:\n",
    "        print('openl3: No WAV files found in {}. Aborting.'.format(str(inputs)))\n",
    "        sys.exit(-1)\n",
    "\n",
    "    # Process all files in the arguments\n",
    "    for filepath in file_list:\n",
    "        if verbose:\n",
    "            print('openl3: Processing: {}'.format(filepath))\n",
    "        process_file(filepath,\n",
    "                     output_dir=output_dir,\n",
    "                     suffix=suffix,\n",
    "                     input_repr=input_repr,\n",
    "                     content_type=content_type,\n",
    "                     embedding_size=embedding_size,\n",
    "                     center=center,\n",
    "                     hop_size=hop_size,\n",
    "                     verbose=verbose)\n",
    "    if verbose:\n",
    "        print('openl3: Done!')\n",
    "\n",
    "\n",
    "def parse_args(args):\n",
    "    \"\"\"Parse the command-line arguments of the openl3 command-line interface.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    args : list of str\n",
    "        Argument strings to parse, without the program name.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    argparse.Namespace\n",
    "        Parsed arguments with the attributes ``inputs``, ``output_dir``,\n",
    "        ``suffix``, ``input_repr``, ``content_type``, ``embedding_size``,\n",
    "        ``no_centering``, ``hop_size`` and ``quiet``.\n",
    "    \"\"\"\n",
    "    # The docstring of main() doubles as the description shown by --help\n",
    "    parser = ArgumentParser(sys.argv[0], description=main.__doc__,\n",
    "                            formatter_class=RawDescriptionHelpFormatter)\n",
    "\n",
    "    parser.add_argument('inputs', nargs='+',\n",
    "                        help='Path or paths to files to process, or path to '\n",
    "                             'a directory of files to process.')\n",
    "\n",
    "    parser.add_argument('--output-dir', '-o', default=None,\n",
    "                        help='Directory to save the output file(s); '\n",
    "                             'if not given, the output will be '\n",
    "                             'saved to the same directory as the input WAV '\n",
    "                             'file(s).')\n",
    "\n",
    "    parser.add_argument('--suffix', '-x', default=None,\n",
    "                        help='String to append to the output filenames. '\n",
    "                             'If not provided, no suffix is added.')\n",
    "\n",
    "    parser.add_argument('--input-repr', '-i', default='mel256',\n",
    "                        choices=['linear', 'mel128', 'mel256'],\n",
    "                        help='String specifying the time-frequency input '\n",
    "                             'representation for the embedding model.')\n",
    "\n",
    "    parser.add_argument('--content-type', '-c', default='music',\n",
    "                        choices=['music', 'env'],\n",
    "                        help='Content type used to train embedding model.')\n",
    "\n",
    "    parser.add_argument('--embedding-size', '-s', type=int, default=6144,\n",
    "                        help='Embedding dimensionality.')\n",
    "\n",
    "    parser.add_argument('--no-centering', '-n', action='store_true', default=False,\n",
    "                        help='Do not pad signal; timestamps will correspond to '\n",
    "                             'the beginning of each analysis window.')\n",
    "\n",
    "    parser.add_argument('--hop-size', '-t', type=positive_float, default=0.1,\n",
    "                        help='Hop size in seconds for processing audio files.')\n",
    "\n",
    "    parser.add_argument('--quiet', '-q', action='store_true', default=False,\n",
    "                        help='Suppress all non-error messages to stdout.')\n",
    "\n",
    "    return parser.parse_args(args)\n",
    "\n",
    "\n",
    "def main():\n",
    "    \"\"\"\n",
    "    Extracts audio embeddings from models based on the Look, Listen, and Learn models (Arandjelovic and Zisserman 2017).\n",
    "    \"\"\"\n",
    "    cli_args = parse_args(sys.argv[1:])\n",
    "\n",
    "    # The command line exposes negative switches (--no-centering, --quiet);\n",
    "    # run() expects the positive forms (center, verbose).\n",
    "    run_options = dict(\n",
    "        output_dir=cli_args.output_dir,\n",
    "        suffix=cli_args.suffix,\n",
    "        input_repr=cli_args.input_repr,\n",
    "        content_type=cli_args.content_type,\n",
    "        embedding_size=cli_args.embedding_size,\n",
    "        center=not cli_args.no_centering,\n",
    "        hop_size=cli_args.hop_size,\n",
    "        verbose=not cli_args.quiet,\n",
    "    )\n",
    "    run(cli_args.inputs, **run_options)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Set paths"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Locate the chirp WAV file inside the openl3 tests directory\n",
    "test_path = os.path.expanduser('~/dev/openl3/tests/')\n",
    "chirp44_path = os.path.join(os.path.join(test_path, 'data', 'audio'), 'chirp_44k.wav')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# path to store output embeddings\n",
    "output_dir = os.path.expanduser('~/openl3_output/')\n",
    "# Create the directory up front so a fresh run does not rely on it already\n",
    "# existing (whether process_file creates it is not visible from here).\n",
    "if not os.path.isdir(output_dir):\n",
    "    os.makedirs(output_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /beegfs/jtc440/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1204: calling reduce_max (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "keep_dims is deprecated, use keepdims instead\n",
      "WARNING:tensorflow:From /beegfs/jtc440/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1238: calling reduce_sum (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "keep_dims is deprecated, use keepdims instead\n",
      "WARNING:tensorflow:From /beegfs/jtc440/miniconda3/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:1255: calling reduce_prod (from tensorflow.python.ops.math_ops) with keep_dims is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "keep_dims is deprecated, use keepdims instead\n"
     ]
    }
   ],
   "source": [
    "# Regression embedding 1: mel256 input, music content, embedding size 6144\n",
    "suffix = None\n",
    "input_repr = 'mel256'\n",
    "content_type = 'music'\n",
    "embedding_size = 6144\n",
    "center = True\n",
    "hop_size = 0.1\n",
    "verbose = False\n",
    "\n",
    "run(chirp44_path, output_dir=output_dir, suffix=suffix,\n",
    "    input_repr=input_repr, content_type=content_type,\n",
    "    embedding_size=embedding_size, center=center,\n",
    "    hop_size=hop_size, verbose=verbose)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Regression embedding 2: linear input, env content, embedding size 512,\n",
    "# no centering and a 0.5 s hop; saved with the 'linear' filename suffix\n",
    "suffix = 'linear'\n",
    "input_repr = 'linear'\n",
    "content_type = 'env'\n",
    "embedding_size = 512\n",
    "center = False\n",
    "hop_size = 0.5\n",
    "verbose = False\n",
    "\n",
    "run(chirp44_path, output_dir=output_dir, suffix=suffix,\n",
    "    input_repr=input_repr, content_type=content_type,\n",
    "    embedding_size=embedding_size, center=center,\n",
    "    hop_size=hop_size, verbose=verbose)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OPTIONAL: compare to previous regression data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Previously stored regression embeddings (inside the tests directory)\n",
    "regression_dir = os.path.join(test_path, 'data', 'regression')\n",
    "reg_emb_path = os.path.join(regression_dir, 'chirp_44k.npz')\n",
    "reg_emb_linear_path = os.path.join(regression_dir, 'chirp_44k_linear.npz')\n",
    "\n",
    "# Embeddings written to output_dir by the run() calls in this notebook\n",
    "new_emb_path = os.path.join(output_dir, 'chirp_44k.npz')\n",
    "new_emb_linear_path = os.path.join(output_dir, 'chirp_44k_linear.npz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the stored regression archives first, then the newly computed ones\n",
    "reg_emb, reg_emb_linear = (np.load(npz_path)\n",
    "                           for npz_path in (reg_emb_path, reg_emb_linear_path))\n",
    "\n",
    "new_emb, new_emb_linear = (np.load(npz_path)\n",
    "                           for npz_path in (new_emb_path, new_emb_linear_path))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Timestamps of the default (no suffix) embedding must match the stored ones\n",
    "assert np.allclose(reg_emb['timestamps'], new_emb['timestamps']), (\n",
    "    'chirp_44k.npz: timestamps differ from the regression data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding values of the default (no suffix) run must match within tolerance\n",
    "assert np.allclose(reg_emb['embedding'], new_emb['embedding'], rtol=1e-05, atol=1e-06, equal_nan=False), (\n",
    "    'chirp_44k.npz: embedding differs from the regression data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Timestamps of the 'linear' embedding must match the stored ones\n",
    "assert np.allclose(reg_emb_linear['timestamps'], new_emb_linear['timestamps']), (\n",
    "    'chirp_44k_linear.npz: timestamps differ from the regression data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding values of the 'linear' run must match within tolerance\n",
    "assert np.allclose(reg_emb_linear['embedding'], new_emb_linear['embedding'], rtol=1e-05, atol=1e-06, equal_nan=False), (\n",
    "    'chirp_44k_linear.npz: embedding differs from the regression data')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
